/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/core/PlanNode.h"
#include "velox/vector/fuzzer/VectorFuzzer.h"
#include "velox/vector/tests/utils/VectorTestBase.h"

using namespace ::facebook::velox;
using namespace ::facebook::velox::core;

namespace {
/// Fixture shared by the PlanNode tests. Exposes a row type with three BIGINT
/// columns (c0, c1, c2) and one fuzzer-generated row vector of that type.
class PlanNodeTest : public testing::Test, public test::VectorTestBase {
 protected:
  /// Sets the testing MemoryManager instance from default-constructed
  /// options.
  static void SetUpTestCase() {
    memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{});
  }

  PlanNodeTest()
      : rowType_(ROW({"c0", "c1", "c2"}, {BIGINT(), BIGINT(), BIGINT()})) {
    // Generate a single input vector using default fuzzer options.
    VectorFuzzer::Options fuzzerOptions;
    VectorFuzzer inputFuzzer(fuzzerOptions, pool_.get());
    rowData_.emplace_back(inputFuzzer.fuzzInputRow(rowType_));
  }

  // Row type used to generate 'rowData_'.
  RowTypePtr rowType_;
  // Input vectors handed to the ValuesNode instances created by the tests.
  std::vector<RowVectorPtr> rowData_;
};

// Checks PlanNode::findFirstNode against the plan
//   Project(0) <- LocalMerge(1) <- {TableScan(2), TableScan(3)}
// for a leaf match, a root match and a predicate nothing satisfies.
TEST_F(PlanNodeTest, findFirstNode) {
  const auto scanType = ROW({"name1"}, {BIGINT()});

  // The scans are never executed, so a null handle and empty assignments are
  // passed.
  std::shared_ptr<connector::ConnectorTableHandle> nullHandle;
  connector::ColumnHandleMap emptyAssignments;

  std::shared_ptr<PlanNode> leftScan = std::make_shared<TableScanNode>(
      "2", scanType, nullHandle, emptyAssignments);
  std::shared_ptr<PlanNode> rightScan = std::make_shared<TableScanNode>(
      "3", scanType, nullHandle, emptyAssignments);

  std::vector<FieldAccessTypedExprPtr> mergeKeys;
  std::vector<SortOrder> mergeOrders;
  std::shared_ptr<PlanNode> merge = std::make_shared<LocalMergeNode>(
      "1",
      mergeKeys,
      mergeOrders,
      std::vector<PlanNodePtr>{leftScan, rightScan});

  std::vector<std::string> projectNames;
  std::vector<TypedExprPtr> projectExprs;
  std::shared_ptr<PlanNode> root =
      std::make_shared<ProjectNode>("0", projectNames, projectExprs, merge);

  const auto hasIdThree = [](const PlanNode* node) {
    return node->id() == "3";
  };
  const auto isProject = [](const PlanNode* node) {
    return node->name() == "Project";
  };
  const auto isUnknown = [](const PlanNode* node) {
    return node->name() == "Unknown";
  };

  // A leaf node is found by its id.
  EXPECT_EQ(rightScan.get(), PlanNode::findFirstNode(root.get(), hasIdThree));

  // The root itself is returned when it satisfies the predicate.
  EXPECT_EQ(root.get(), PlanNode::findFirstNode(root.get(), isProject));

  // No node matches, so the search yields nullptr.
  EXPECT_EQ(nullptr, PlanNode::findFirstNode(root.get(), isUnknown));
}

// Checks PlanNode::findNodeById over the linear plan
//   Limit(4) <- Filter(3) <- Project(2) <- Values(1).
TEST_F(PlanNodeTest, findNodeById) {
  const auto makeRand = []() -> TypedExprPtr {
    return std::make_shared<CallTypedExpr>(DOUBLE(), "rand");
  };

  auto valuesNode =
      std::make_shared<ValuesNode>("1", std::vector<RowVectorPtr>{});
  auto projectNode = std::make_shared<ProjectNode>(
      "2",
      std::vector<std::string>{"a", "b"},
      std::vector<TypedExprPtr>{makeRand(), makeRand()},
      valuesNode);

  // Filter condition: gt(a, 0.5).
  auto condition = std::make_shared<CallTypedExpr>(
      BOOLEAN(),
      "gt",
      std::make_shared<FieldAccessTypedExpr>(DOUBLE(), "a"),
      std::make_shared<ConstantTypedExpr>(DOUBLE(), 0.5));
  auto filterNode = std::make_shared<FilterNode>("3", condition, projectNode);

  auto limitNode = std::make_shared<LimitNode>("4", 0, 10, false, filterNode);

  // Every node is found when the search starts at the top of the plan.
  const PlanNode* root = limitNode.get();
  ASSERT_EQ(PlanNode::findNodeById(root, "1"), valuesNode.get());
  ASSERT_EQ(PlanNode::findNodeById(root, "2"), projectNode.get());
  ASSERT_EQ(PlanNode::findNodeById(root, "3"), filterNode.get());
  ASSERT_EQ(PlanNode::findNodeById(root, "4"), limitNode.get());

  // An id that is not in the plan yields nullptr.
  ASSERT_EQ(PlanNode::findNodeById(root, "5"), nullptr);
  // The limit node sits above the project node, so a search that starts at the
  // project node does not find it.
  ASSERT_EQ(PlanNode::findNodeById(projectNode.get(), "4"), nullptr);
}

// Checks that PlanNode::is<T>() is true only for the node's own type, using a
// ValuesNode and a ProjectNode.
TEST_F(PlanNodeTest, is) {
  const auto makeRand = []() -> TypedExprPtr {
    return std::make_shared<CallTypedExpr>(DOUBLE(), "rand");
  };

  auto valuesNode =
      std::make_shared<ValuesNode>("1", std::vector<RowVectorPtr>{});
  auto projectNode = std::make_shared<ProjectNode>(
      "2",
      std::vector<std::string>{"a", "b"},
      std::vector<TypedExprPtr>{makeRand(), makeRand()},
      valuesNode);

  // Values node.
  ASSERT_TRUE(valuesNode->is<ValuesNode>());
  ASSERT_FALSE(valuesNode->is<ProjectNode>());

  // Project node.
  ASSERT_FALSE(projectNode->is<ValuesNode>());
  ASSERT_TRUE(projectNode->is<ProjectNode>());
}

// Checks SortOrder equality and inequality across combinations of its two
// boolean constructor arguments.
TEST_F(PlanNodeTest, sortOrder) {
  struct {
    SortOrder order1;
    SortOrder order2;
    // Whether 'order1' and 'order2' are expected to compare equal. Declared as
    // bool so the trace below prints true/false rather than 1/0.
    bool expectedEqual;

    // Describes the test case; attached to assertion failures via
    // SCOPED_TRACE.
    std::string debugString() const {
      return fmt::format(
          "order1 {} order2 {} expectedEqual {}",
          order1.toString(),
          order2.toString(),
          expectedEqual);
    }
  } testSettings[] = {
      // Identical settings compare equal.
      {{true, true}, {true, true}, true},
      {{true, false}, {true, false}, true},
      {{false, true}, {false, true}, true},
      {{false, false}, {false, false}, true},
      // Settings that differ in at least one argument compare unequal.
      {{true, true}, {true, false}, false},
      {{true, true}, {false, false}, false},
      {{true, true}, {false, true}, false},
      {{true, false}, {false, false}, false},
      {{true, false}, {false, true}, false},
      {{false, true}, {false, false}, false}};
  for (const auto& testData : testSettings) {
    SCOPED_TRACE(testData.debugString());
    if (testData.expectedEqual) {
      ASSERT_EQ(testData.order1, testData.order2);
    } else {
      ASSERT_NE(testData.order1, testData.order2);
    }
  }
}

// Checks that constructing an OrderByNode fails with a user error when the
// same sorting key ("c0") is listed more than once.
TEST_F(PlanNodeTest, duplicateSortKeys) {
  const auto bigintKey =
      [](const std::string& column) -> FieldAccessTypedExprPtr {
    return std::make_shared<FieldAccessTypedExpr>(BIGINT(), column);
  };

  // "c0" is listed as both the first and the third key.
  const std::vector<FieldAccessTypedExprPtr> keys{
      bigintKey("c0"), bigintKey("c1"), bigintKey("c0")};
  const std::vector<SortOrder> orders{
      {true, true}, {false, false}, {true, true}};

  VELOX_ASSERT_USER_THROW(
      std::make_shared<OrderByNode>("orderBy", keys, orders, false, nullptr),
      "Duplicate sorting keys are not allowed: c0");
}

class TestIndexTableHandle : public connector::ConnectorTableHandle {
 public:
  TestIndexTableHandle()
      : connector::ConnectorTableHandle("TestIndexConnnector") {}

  ~TestIndexTableHandle() override = default;

  std::string toString() const override {
    return "TestIndexTableHandle";
  }

  const std::string& name() const override {
    static const std::string kName = "TestIndexTableHandle";
    return kName;
  }

  bool supportsIndexLookup() const override {
    return true;
  }

  folly::dynamic serialize() const override {
    return {};
  }

  static std::shared_ptr<TestIndexTableHandle> create(
      const folly::dynamic& obj,
      void* context) {
    return std::make_shared<TestIndexTableHandle>();
  }
};

// Checks isIndexLookupJoin() and IndexLookupJoinNode construction: valid
// inner, left-with-marker and filtered joins (accessors and toString output),
// followed by invalid marker / output type combinations that must throw.
TEST_F(PlanNodeTest, indexLookupJoin) {
  const auto valueNode = std::make_shared<ValuesNode>("orderBy", rowData_);
  ASSERT_FALSE(isIndexLookupJoin(valueNode.get()));

  const RowTypePtr probeType = ROW({"c0"}, {BIGINT()});
  const RowTypePtr buildType = ROW({"c1"}, {BIGINT()});
  const RowTypePtr outputType = ROW({"c0", "c1"}, {BIGINT(), BIGINT()});
  auto indexTableHandle = std::make_shared<TestIndexTableHandle>();
  const auto probeNode = std::make_shared<TableScanNode>(
      "tableScan-probe", probeType, nullptr, connector::ColumnHandleMap{});
  ASSERT_FALSE(isIndexLookupJoin(probeNode.get()));
  const auto buildNode = std::make_shared<TableScanNode>(
      "tableScan-build",
      buildType,
      indexTableHandle,
      connector::ColumnHandleMap{});
  ASSERT_FALSE(isIndexLookupJoin(buildNode.get()));
  const std::vector<FieldAccessTypedExprPtr> leftKeys{
      std::make_shared<FieldAccessTypedExpr>(BIGINT(), "c0")};
  const std::vector<FieldAccessTypedExprPtr> rightKeys{
      std::make_shared<FieldAccessTypedExpr>(BIGINT(), "c1")};

  // Builds an IndexLookupJoinNode over the probe/build scans and join keys
  // above, with no additional join conditions. Only the arguments that vary
  // between the cases below are parameters.
  const auto makeIndexJoin = [&](const PlanNodeId& id,
                                 core::JoinType joinType,
                                 const TypedExprPtr& filter,
                                 bool hasMarker,
                                 const RowTypePtr& joinOutputType) {
    return std::make_shared<IndexLookupJoinNode>(
        id,
        joinType,
        leftKeys,
        rightKeys,
        std::vector<IndexLookupConditionPtr>{},
        filter,
        hasMarker,
        probeNode,
        buildNode,
        joinOutputType);
  };

  {
    // Inner join without marker or filter.
    const auto indexJoinNodeWithInnerJoin = makeIndexJoin(
        "indexJoinNode",
        core::JoinType::kInner,
        /*filter=*/nullptr,
        /*hasMarker=*/false,
        outputType);
    ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithInnerJoin.get()));
    ASSERT_FALSE(indexJoinNodeWithInnerJoin->hasMarker());
    ASSERT_EQ(indexJoinNodeWithInnerJoin->filter(), nullptr);
    ASSERT_EQ(
        indexJoinNodeWithInnerJoin->toString(/*detailed=*/true),
        "-- IndexLookupJoin[indexJoinNode][INNER c0=c1] -> c0:BIGINT, c1:BIGINT\n");
  }
  {
    // Left join with a marker: the output type ends with a BOOLEAN column.
    const RowTypePtr outputTypeWithMatchColumn =
        ROW({"c0", "c1", "c2"}, {BIGINT(), BIGINT(), BOOLEAN()});
    const auto indexJoinNodeWithLeftJoin = makeIndexJoin(
        "indexJoinNode",
        core::JoinType::kLeft,
        /*filter=*/nullptr,
        /*hasMarker=*/true,
        outputTypeWithMatchColumn);
    ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithLeftJoin.get()));
    ASSERT_TRUE(indexJoinNodeWithLeftJoin->hasMarker());
    ASSERT_EQ(indexJoinNodeWithLeftJoin->filter(), nullptr);
    ASSERT_EQ(
        indexJoinNodeWithLeftJoin->toString(/*detailed=*/true),
        "-- IndexLookupJoin[indexJoinNode][LEFT c0=c1] -> c0:BIGINT, c1:BIGINT, c2:BOOLEAN\n");
  }
  {
    // Test IndexLookupJoinNode with filter.
    const auto filterExpr = std::make_shared<core::FieldAccessTypedExpr>(
        BOOLEAN(), "filter_column");
    const auto indexJoinNodeWithFilter = makeIndexJoin(
        "indexJoinNodeWithFilter",
        core::JoinType::kInner,
        /*filter=*/filterExpr,
        /*hasMarker=*/false,
        outputType);
    ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithFilter.get()));
    ASSERT_FALSE(indexJoinNodeWithFilter->hasMarker());
    ASSERT_EQ(indexJoinNodeWithFilter->filter(), filterExpr);
    ASSERT_EQ(
        indexJoinNodeWithFilter->toString(/*detailed=*/true),
        "-- IndexLookupJoin[indexJoinNodeWithFilter][INNER c0=c1, filter: \"filter_column\"] -> c0:BIGINT, c1:BIGINT\n");
  }

  // Error cases.
  {
    // A marker is requested on an inner join.
    VELOX_ASSERT_THROW(
        makeIndexJoin(
            "indexJoinNode",
            core::JoinType::kInner,
            /*filter=*/nullptr,
            /*hasMarker=*/true,
            outputType),
        "Index join match column can only present for LEFT but not INNER");
  }
  {
    // A marker is requested but the last output column is BIGINT.
    VELOX_ASSERT_THROW(
        makeIndexJoin(
            "indexJoinNode",
            core::JoinType::kLeft,
            /*filter=*/nullptr,
            /*hasMarker=*/true,
            outputType),
        "The last output column must be boolean type if match column is present");
  }
  {
    // The trailing BOOLEAN column reuses the name "c0". Any error message is
    // accepted.
    const RowTypePtr outputTypeWithDuplicateMatchColumn =
        ROW({"c0", "c1", "c0"}, {BIGINT(), BIGINT(), BOOLEAN()});
    VELOX_ASSERT_THROW(
        makeIndexJoin(
            "indexJoinNode",
            core::JoinType::kLeft,
            /*filter=*/nullptr,
            /*hasMarker=*/true,
            outputTypeWithDuplicateMatchColumn),
        "");
  }
}

// Checks which combinations of kind, partitioning keys, partition count and
// partition function spec the PartitionedOutputNode constructor accepts.
TEST_F(PlanNodeTest, partitionedOutputNode) {
  const PlanNodeId nodeId{"partitionedOutputNode"};
  const auto partitionedKind = PartitionedOutputNode::Kind::kPartitioned;
  const auto arbitraryKind = PartitionedOutputNode::Kind::kArbitrary;
  const std::vector<TypedExprPtr> noKeys;
  const std::vector<TypedExprPtr> c0Key = {
      std::make_shared<FieldAccessTypedExpr>(BIGINT(), "c0")};
  const PartitionFunctionSpecPtr gatherSpec =
      std::make_shared<GatherPartitionFunctionSpec>();
  const VectorSerde::Kind serde = VectorSerde::Kind::kPresto;
  PlanNodePtr input = std::make_shared<ValuesNode>("source", rowData_);

  // Constructs a node from the arguments that vary between cases; the id,
  // replicateNullsAndAny=true, output type, serde kind and source are fixed.
  const auto makeNode = [&](PartitionedOutputNode::Kind kind,
                            const std::vector<TypedExprPtr>& keys,
                            int numPartitions,
                            const PartitionFunctionSpecPtr& spec) {
    return PartitionedOutputNode(
        nodeId,
        kind,
        keys,
        numPartitions,
        /*replicateNullsAndAny=*/true,
        spec,
        rowType_,
        serde,
        input);
  };

  {
    // A single partition with empty keys and a null partition function spec is
    // accepted.
    auto node =
        makeNode(partitionedKind, noKeys, /*numPartitions=*/1, nullptr);
    ASSERT_EQ(node.partitionFunctionSpecPtr(), nullptr);
    // Dereferencing the null spec through the reference accessor must throw.
    VELOX_ASSERT_THROW(node.partitionFunctionSpec(), "");
  }

  {
    // A non-partitioned kind with empty keys and a partition function spec is
    // accepted (kinds other than partitioned still use a partition function).
    auto node =
        makeNode(arbitraryKind, noKeys, /*numPartitions=*/10, gatherSpec);
    // The spec is returned both as a pointer and as a reference.
    ASSERT_EQ(node.partitionFunctionSpecPtr(), gatherSpec);
    ASSERT_EQ(node.partitionFunctionSpec().toString(), gatherSpec->toString());
  }

  // numPartitions = 0 is rejected.
  VELOX_ASSERT_THROW(
      makeNode(partitionedKind, c0Key, /*numPartitions=*/0, gatherSpec), "");

  // numPartitions = 1 together with non-empty keys is rejected.
  VELOX_ASSERT_THROW(
      makeNode(partitionedKind, c0Key, /*numPartitions=*/1, gatherSpec),
      "Non-empty partitioning keys require more than one partition");

  // numPartitions > 1 without a partition function spec is rejected.
  VELOX_ASSERT_THROW(
      makeNode(partitionedKind, c0Key, /*numPartitions=*/5, nullptr),
      "Partition function spec must be specified when the number of destinations is more than 1.");

  // A non-partitioned kind with non-empty keys is rejected.
  VELOX_ASSERT_THROW(
      makeNode(arbitraryKind, c0Key, /*numPartitions=*/5, gatherSpec),
      "partitioning doesn't allow for partitioning keys");
}

// Checks the noGroupsSpanBatches flag of AggregationNode: it is accepted only
// together with pre-grouped keys, and it is reflected in both the accessor and
// the toString output.
TEST_F(PlanNodeTest, aggregationNodeNoGroupsSpanBatches) {
  auto input = std::make_shared<ValuesNode>("values", rowData_);

  const std::vector<FieldAccessTypedExprPtr> groupingKeys{
      std::make_shared<FieldAccessTypedExpr>(BIGINT(), "c0")};
  const std::vector<FieldAccessTypedExprPtr> preGroupedOnC0{
      std::make_shared<FieldAccessTypedExpr>(BIGINT(), "c0")};
  const std::vector<FieldAccessTypedExprPtr> notPreGrouped;
  const std::vector<std::string> aggregateNames{"sum"};
  const std::vector<AggregationNode::Aggregate> aggregates{
      {.call = std::make_shared<CallTypedExpr>(BIGINT(), "sum"),
       .rawInputTypes = {BIGINT()}}};

  // Builds a single-step aggregation on c0 that varies only in the pre-grouped
  // keys and the flag under test.
  const auto makeAggregation =
      [&](const std::vector<FieldAccessTypedExprPtr>& preGroupedKeys,
          bool noGroupsSpanBatches) {
        return std::make_shared<AggregationNode>(
            "agg",
            AggregationNode::Step::kSingle,
            groupingKeys,
            preGroupedKeys,
            aggregateNames,
            aggregates,
            /*ignoreNullKeys=*/false,
            noGroupsSpanBatches,
            input);
      };

  // Pre-grouped keys (streaming aggregation) with the flag set: construction
  // succeeds and the accessor returns true.
  {
    const auto node =
        makeAggregation(preGroupedOnC0, /*noGroupsSpanBatches=*/true);
    ASSERT_TRUE(node->noGroupsSpanBatches());
    ASSERT_TRUE(node->isPreGrouped());
    ASSERT_EQ(
        node->toString(true),
        "-- Aggregation[agg][SINGLE STREAMING [c0] sum := sum() noGroupsSpanBatches] -> c0:BIGINT, sum:BIGINT\n");
  }

  // Pre-grouped keys with the flag cleared: construction succeeds and the
  // accessor returns false.
  {
    const auto node =
        makeAggregation(preGroupedOnC0, /*noGroupsSpanBatches=*/false);
    ASSERT_FALSE(node->noGroupsSpanBatches());
    ASSERT_TRUE(node->isPreGrouped());
    ASSERT_EQ(
        node->toString(true),
        "-- Aggregation[agg][SINGLE STREAMING [c0] sum := sum()] -> c0:BIGINT, sum:BIGINT\n");
  }

  // No pre-grouped keys (non-streaming aggregation) with the flag set:
  // construction must fail.
  VELOX_ASSERT_THROW(
      makeAggregation(notPreGrouped, /*noGroupsSpanBatches=*/true),
      "noGroupsSpanBatches can only be set for streaming aggregation (pre-grouped)");

  // No pre-grouped keys with the flag cleared: construction succeeds.
  {
    const auto node =
        makeAggregation(notPreGrouped, /*noGroupsSpanBatches=*/false);
    ASSERT_FALSE(node->noGroupsSpanBatches());
    ASSERT_FALSE(node->isPreGrouped());
    ASSERT_EQ(
        node->toString(true),
        "-- Aggregation[agg][SINGLE [c0] sum := sum()] -> c0:BIGINT, sum:BIGINT\n");
  }
}
} // namespace
